import time
import kgdlg.utils.misc_utils as utils
import torch
from torch.autograd import Variable
import random
import os
import sys
import math
class Statistics(object):
    """
    Accumulator for train/validate loss statistics of the seq2seq model:
    sequence (cross-entropy) loss, KL-divergence loss, token counts, and
    token-level accuracy.
    """
    def __init__(self, seq_loss=0, kld_loss=0, n_words=0, n_correct=0):
        self.seq_loss = seq_loss        # summed sequence (reconstruction) loss
        self.kld_loss = kld_loss        # summed KL-divergence loss
        self.n_words = n_words          # number of target tokens scored
        self.n_correct = n_correct      # number of correctly predicted tokens
        self.n_src_words = 0
        self.start_time = time.time()   # creation time; basis for elapsed_time()

    def update(self, stat):
        """Merge another Statistics accumulator into this one."""
        self.seq_loss += stat.seq_loss
        self.kld_loss += stat.kld_loss
        self.n_words += stat.n_words
        self.n_correct += stat.n_correct

    def ppl(self):
        """Per-word perplexity. Raises ZeroDivisionError if n_words == 0."""
        return utils.safe_exp(self.seq_loss / self.n_words)

    def accuracy(self):
        """Token-level accuracy in percent. Raises ZeroDivisionError if n_words == 0."""
        return 100 * (self.n_correct / self.n_words)

    def elapsed_time(self):
        """Seconds elapsed since this object was created."""
        return time.time() - self.start_time

    def print_out(self, epoch, batch, n_batches, start):
        """Print a one-line training progress report and flush stdout.

        `start` is unused; kept for backward compatibility with callers.
        """
        # Fix: use the elapsed time computed once instead of re-reading the clock.
        t = self.elapsed_time()

        out_info = ("Epoch %2d, %5d/%5d| acc: %6.2f| ppl: %6.2f| " + \
               "kld: %6.4f| %4.0f s elapsed") % \
              (epoch, batch, n_batches,
               self.accuracy(),
               self.ppl(),
               self.kld_loss,
               t)

        print(out_info)
        sys.stdout.flush()


class PriorStatistics(object):
    """
    Accumulator for prior-network training statistics: cluster-prediction
    loss and keyword-prediction loss.
    """
    def __init__(self, cluster_loss=0, keyword_loss=0):
        self.cluster_loss = cluster_loss  # summed cluster-classification loss
        self.keyword_loss = keyword_loss  # summed keyword-prediction loss
        self.start_time = time.time()     # creation time; basis for elapsed_time()

    def update(self, stat):
        """Merge another PriorStatistics accumulator into this one."""
        self.cluster_loss += stat.cluster_loss
        self.keyword_loss += stat.keyword_loss

    def elapsed_time(self):
        """Seconds elapsed since this object was created."""
        return time.time() - self.start_time

    def print_out(self, epoch, batch, n_batches, start):
        """Print a one-line training progress report and flush stdout.

        `start` is unused; kept for backward compatibility with callers.
        """
        # Fix: use the elapsed time computed once instead of re-reading the clock.
        t = self.elapsed_time()

        out_info = ("Epoch %2d, %5d/%5d| c: %6.2f| kw: %6.2f" + \
               "| %4.0f s elapsed") % \
              (epoch, batch, n_batches,
               self.cluster_loss,
               self.keyword_loss,
               t)

        print(out_info)
        sys.stdout.flush()

class Trainer(object):
    """Drives seq2seq model training: one forward/loss/optimizer step per batch,
    per-epoch statistics aggregation, and checkpoint saving."""

    def __init__(self, opt, model, train_iter,
                 train_loss, optim, lr_scheduler):
        self.opt = opt                    # experiment options/config object
        self.model = model
        self.train_iter = train_iter      # batch iterator (torchtext-style)
        self.train_loss = train_loss      # loss computer
        self.optim = optim
        self.lr_scheduler = lr_scheduler

        # Set model in training mode (enables dropout etc.).
        self.model.train()

        self.global_step = 0
        self.step_epoch = 0

    def update(self, batch):
        """Run one training step on `batch` and return its Statistics.

        NOTE(review): no explicit loss.backward() here — presumably
        train_loss.compute_train_loss performs the backward pass internally;
        verify before relying on this.
        """
        self.model.zero_grad()
        src_inputs = batch.src[0]
        src_lengths = batch.src[1].tolist()
        # Drop the final target token: the model predicts it from the prefix.
        tgt_inputs = batch.tgt[:-1]
        outputs, attn = self.model(src_inputs, tgt_inputs, src_lengths)
        stats = self.train_loss.compute_train_loss(batch, outputs)

        self.optim.step()
        return stats

    def train(self, epoch, report_func=None):
        """Train for one epoch; returns the epoch's aggregated Statistics.

        `report_func`, if given, is called after every batch and may return a
        (possibly reset) report Statistics object.
        """
        total_stats = Statistics()
        report_stats = Statistics()

        for batch in self.train_iter:
            self.global_step += 1
            step_batch = self.train_iter.iterations
            stats = self.update(batch)

            report_stats.update(stats)
            total_stats.update(stats)

            if report_func is not None:
                report_stats = report_func(self.global_step,
                        epoch, step_batch, len(self.train_iter),
                        total_stats.start_time, self.optim.lr, report_stats)

        return total_stats

    def save_per_epoch(self, epoch, out_dir):
        """Write the 'checkpoint' pointer file and save the model checkpoint."""
        # Fix: use a context manager so the file is closed even on error.
        with open(os.path.join(out_dir, 'checkpoint'), 'w') as f:
            f.write('latest_checkpoint:checkpoint_epoch%d.pkl' % (epoch))
        self.model.save_checkpoint(epoch, self.opt,
                    os.path.join(out_dir, "checkpoint_epoch%d.pkl" % (epoch)))

    def epoch_step(self, epoch, out_dir):
        """Called at the end of each epoch; saves a checkpoint.

        NOTE(review): learning-rate scheduling is currently disabled —
        self.lr_scheduler is stored but never stepped.
        """
        self.save_per_epoch(epoch, out_dir)


class PriorTrainer(object):
    """Trains the prior network, which predicts cluster and keyword logits
    from the source sequence and bag-of-words features."""

    def __init__(self, opt, model, train_iter,
                 train_loss, optim, lr_scheduler):
        self.opt = opt                    # experiment options/config object
        self.model = model
        self.train_iter = train_iter      # batch iterator (torchtext-style)
        self.train_loss = train_loss      # loss computer
        self.optim = optim
        self.lr_scheduler = lr_scheduler

        # Set model in training mode (enables dropout etc.).
        self.model.train()

        self.global_step = 0
        self.step_epoch = 0

    def update(self, batch):
        """Run one training step on `batch` and return its statistics.

        NOTE(review): the original also computed target inputs/lengths here
        (using `batch.tgt[0]-1` for lengths, likely a typo for `batch.tgt[1]-1`,
        cf. RecogTrainer) but never used them; removed as dead code.
        NOTE(review): no explicit loss.backward() — presumably
        train_loss.compute_loss performs the backward pass internally; verify.
        """
        self.model.zero_grad()
        src_inputs = batch.src[0]
        src_lengths = batch.src[1].tolist()
        bow_inputs = batch.bow[0]
        clusters = batch.cluster
        cluster_logits, keyword_logits = self.model(src_inputs, src_lengths,
                                                    clusters, bow_inputs)

        stats = self.train_loss.compute_loss(batch, cluster_logits, keyword_logits)
        self.optim.step()
        return stats

    def train(self, epoch, report_func=None):
        """Train for one epoch; returns the epoch's aggregated PriorStatistics."""
        total_stats = PriorStatistics()
        report_stats = PriorStatistics()

        for batch in self.train_iter:
            self.global_step += 1
            step_batch = self.train_iter.iterations
            stats = self.update(batch)

            report_stats.update(stats)
            total_stats.update(stats)

            if report_func is not None:
                report_stats = report_func(self.global_step,
                        epoch, step_batch, len(self.train_iter),
                        total_stats.start_time, self.optim.lr, report_stats)

        return total_stats

    def save_per_epoch(self, epoch, out_dir):
        """Write the 'checkpoint' pointer file and save the model checkpoint."""
        # Fix: use a context manager so the file is closed even on error.
        with open(os.path.join(out_dir, 'checkpoint'), 'w') as f:
            f.write('latest_checkpoint:checkpoint_epoch%d.pkl' % (epoch))
        self.model.save_checkpoint(epoch, self.opt,
                    os.path.join(out_dir, "checkpoint_epoch%d.pkl" % (epoch)))

    def epoch_step(self, epoch, out_dir):
        """Called at the end of each epoch; saves a checkpoint."""
        self.save_per_epoch(epoch, out_dir)


class RecogTrainer(object):
    """Trains the recognition network, which predicts cluster and keyword
    logits from both the source and the target sequence."""

    def __init__(self, opt, model, train_iter,
                 train_loss, optim, lr_scheduler):
        self.opt = opt                    # experiment options/config object
        self.model = model
        self.train_iter = train_iter      # batch iterator (torchtext-style)
        self.train_loss = train_loss      # loss computer
        self.optim = optim
        self.lr_scheduler = lr_scheduler

        # Set model in training mode (enables dropout etc.).
        self.model.train()

        self.global_step = 0
        self.step_epoch = 0

    def update(self, batch):
        """Run one training step on `batch` and return its statistics.

        NOTE(review): no explicit loss.backward() — presumably
        train_loss.compute_loss performs the backward pass internally; verify.
        """
        self.model.zero_grad()
        src_inputs = batch.src[0]
        src_lengths = batch.src[1].tolist()
        # Drop the final target token; lengths shrink by one accordingly.
        tgt_inputs = batch.tgt[0][:-1]
        tgt_lengths = (batch.tgt[1] - 1).tolist()

        clusters = batch.cluster
        cluster_logits, keyword_logits = self.model(src_inputs, src_lengths,
                                                    tgt_inputs, tgt_lengths,
                                                    clusters)

        stats = self.train_loss.compute_loss(batch, cluster_logits, keyword_logits)
        self.optim.step()
        return stats

    def train(self, epoch, report_func=None):
        """Train for one epoch; returns the epoch's aggregated PriorStatistics."""
        total_stats = PriorStatistics()
        report_stats = PriorStatistics()

        for batch in self.train_iter:
            self.global_step += 1
            step_batch = self.train_iter.iterations
            stats = self.update(batch)

            report_stats.update(stats)
            total_stats.update(stats)

            if report_func is not None:
                report_stats = report_func(self.global_step,
                        epoch, step_batch, len(self.train_iter),
                        total_stats.start_time, self.optim.lr, report_stats)

        return total_stats

    def save_per_epoch(self, epoch, out_dir):
        """Write the 'checkpoint' pointer file and save the model checkpoint."""
        # Fix: use a context manager so the file is closed even on error.
        with open(os.path.join(out_dir, 'checkpoint'), 'w') as f:
            f.write('latest_checkpoint:checkpoint_epoch%d.pkl' % (epoch))
        self.model.save_checkpoint(epoch, self.opt,
                    os.path.join(out_dir, "checkpoint_epoch%d.pkl" % (epoch)))

    def epoch_step(self, epoch, out_dir):
        """Called at the end of each epoch; saves a checkpoint."""
        self.save_per_epoch(epoch, out_dir)



class JointTrainer(object):
    """Trains the joint model over batches carrying multiple target responses
    (fields tgt0, tgt1, ... on each batch)."""

    def __init__(self, opt, model, train_iter,
                 train_loss, optim, lr_scheduler, set_multiple_ans):
        self.opt = opt                    # experiment options/config object
        self.model = model
        self.train_iter = train_iter      # batch iterator (torchtext-style)
        self.train_loss = train_loss      # loss computer
        self.optim = optim
        self.lr_scheduler = lr_scheduler

        # Set model in training mode (enables dropout etc.).
        self.model.train()

        self.global_step = 0
        self.step_epoch = 0
        self.set_multiple_ans = set_multiple_ans

    def update(self, batch):
        """Run one training step on `batch` and return its statistics.

        FIX(review): the original contained debug leftovers — a print of the
        per-target tensors followed by os._exit(0), which killed the process on
        the first batch, and the tgt input/length lists were never populated.
        The debug exit/print are removed and the lists are filled per target,
        which appears to be the intent; confirm against the model's API.

        NOTE(review): no explicit loss.backward() — presumably
        train_loss.compute_loss performs the backward pass internally; verify.
        """
        self.model.zero_grad()
        src_inputs = batch.src[0]
        src_lengths = batch.src[1].tolist()
        # Bag-of-words tensor; not currently fed to the model here.
        bow_inputs = batch.bow[0]

        tgt_inputs_list = []
        tgt_length_list = []
        batch_dict = vars(batch)

        # Count how many target fields (tgt0, tgt1, ...) this batch carries.
        number_tgt = 0
        for name in batch_dict:
            if 'tgt' in name:
                number_tgt = number_tgt + 1

        for i in range(number_tgt):
            tgt_name = 'tgt' + str(i)
            similar_key_name = 'similar_key' + str(i)
            # NOTE(review): this field name starts with a BOM character
            # (U+FEFF), evidently matching a BOM in the data headers —
            # do not "fix" the key without checking the dataset fields.
            peculiarity_key_name = '\ufeffpeculiarity_key' + str(i)

            # Drop the final target token; lengths shrink by one accordingly.
            tgt_inputs = batch_dict[tgt_name][0][:-1]
            tgt_lengths = batch_dict[tgt_name][1] - 1

            # Currently read but unused downstream — kept so schema mismatches
            # still surface as KeyError, as in the original code.
            similar_keys = batch_dict[similar_key_name][0][:-1]
            peculiarity_keys = batch_dict[peculiarity_key_name][0][:-1]

            tgt_inputs_list.append(tgt_inputs)
            # .tolist() matches the length format used by the other trainers —
            # TODO confirm the joint model expects python lists here.
            tgt_length_list.append(tgt_lengths.tolist())

        if self.model.train_mode == 110:
            # Mode 110: first half of the targets are positives, second half
            # negatives; negatives contribute with weight -0.1.
            half = len(tgt_inputs_list) // 2
            output_list, kld_loss, bow_logit = self.model(src_inputs, src_lengths,
                                                                tgt_inputs_list[:half], tgt_length_list[:half])

            neg_output_list, neg_kld_loss, neg_bow_logits = self.model(src_inputs, src_lengths,
                                                                tgt_inputs_list[half:], tgt_length_list[half:])
            for neg_output in neg_output_list:
                output_list.append((-0.1) * neg_output)
            kld_loss = kld_loss + (-0.1) * neg_kld_loss
            bow_logit = bow_logit + (-0.1) * neg_bow_logits

            stats = self.train_loss.compute_loss(batch, output_list, kld_loss, bow_logit)
        else:
            output_list, kld_loss, bow_logits_list = self.model(src_inputs, src_lengths,
                                                                tgt_inputs_list, tgt_length_list)
            stats = self.train_loss.compute_loss(batch, output_list, kld_loss, bow_logits_list)

        self.optim.step()
        return stats

    def train(self, epoch, report_func=None):
        """Train for one epoch; returns the epoch's aggregated Statistics."""
        total_stats = Statistics()
        report_stats = Statistics()

        for batch in self.train_iter:
            self.global_step += 1
            step_batch = self.train_iter.iterations
            stats = self.update(batch)

            report_stats.update(stats)
            total_stats.update(stats)

            if report_func is not None:
                report_stats = report_func(self.global_step,
                        epoch, step_batch, len(self.train_iter),
                        total_stats.start_time, self.optim.lr, report_stats)

        return total_stats

    def save_per_epoch(self, epoch, out_dir):
        """Write the 'checkpoint' pointer file and save the model checkpoint."""
        # Fix: use a context manager so the file is closed even on error.
        with open(os.path.join(out_dir, 'checkpoint'), 'w') as f:
            f.write('latest_checkpoint:checkpoint_epoch%d.pkl' % (epoch))
        self.model.save_checkpoint(epoch, self.opt,
                    os.path.join(out_dir, "checkpoint_epoch%d.pkl" % (epoch)))

    def epoch_step(self, epoch, out_dir):
        """Called at the end of each epoch; saves a checkpoint."""
        self.save_per_epoch(epoch, out_dir)